{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import pathlib\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import scipy.sparse\n",
    "import scipy.io\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Directory prefix under which the preprocessed output files are written.\n",
    "save_prefix = 'data/preprocessed/LastFM_magnn_processed/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Raw LastFM tables: tab-separated text files; the column names are supplied through `names`.\n",
    "user_artist = pd.read_csv(\n",
    "    'data/raw/LastFM_magnn/user_artist.dat',\n",
    "    sep='\\t',\n",
    "    encoding='utf-8',\n",
    "    names=['userID', 'artistID', 'weight'],\n",
    ")\n",
    "user_friend = pd.read_csv(\n",
    "    'data/raw/LastFM_magnn/user_user(original).dat',\n",
    "    sep='\\t',\n",
    "    encoding='utf-8',\n",
    "    names=['userID', 'friendID'],\n",
    ")\n",
    "artist_tag = pd.read_csv(\n",
    "    'data/raw/LastFM_magnn/artist_tag.dat',\n",
    "    sep='\\t',\n",
    "    encoding='utf-8',\n",
    "    names=['artistID', 'tagID'],\n",
    ")\n",
    "\n",
    "# Number of nodes of each type (users, artists, tags).\n",
    "num_user, num_artist, num_tag = 1892, 17632, 11945"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load the stored train / validation / test split; each entry holds row labels of user_artist.\n",
    "train_val_test_idx = np.load('data/raw/LastFM_magnn/train_val_test_idx.npz')\n",
    "train_idx, val_idx, test_idx = (\n",
    "    train_val_test_idx[split_key] for split_key in ('train_idx', 'val_idx', 'test_idx')\n",
    ")\n",
    "\n",
    "# Keep only the training rows and renumber them from 0.\n",
    "# NOTE: this overwrites user_artist, so re-running this cell filters the already filtered frame;\n",
    "# re-run the loading cell first.\n",
    "user_artist = user_artist.loc[train_idx].reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# build the adjacency matrix\n",
    "# 0 for user, 1 for artist, 2 for tag\n",
    "# Global node index layout: [users | artists | tags]. Raw IDs are shifted by -1 to obtain\n",
    "# 0-based positions (assumes the .dat files use 1-based IDs -- TODO confirm).\n",
    "dim = num_user + num_artist + num_tag\n",
    "\n",
    "type_mask = np.zeros((dim), dtype=int)\n",
    "type_mask[num_user:num_user+num_artist] = 1\n",
    "type_mask[num_user+num_artist:] = 2\n",
    "\n",
    "# Dense (dim x dim) integer matrix: its memory footprint grows quadratically with dim.\n",
    "adjM = np.zeros((dim, dim), dtype=int)\n",
    "\n",
    "# Rows are walked with itertuples() rather than iterrows(): no Series is built per row and\n",
    "# the column values are handed back without dtype coercion. Entries are written in file order,\n",
    "# so a repeated pair keeps the value of its last occurrence.\n",
    "\n",
    "# user-artist links, stored symmetrically; the value is the `weight` column floored at 1\n",
    "for user_id, artist_id, weight in user_artist[['userID', 'artistID', 'weight']].itertuples(index=False, name=None):\n",
    "    uid = user_id - 1\n",
    "    aid = num_user + artist_id - 1\n",
    "    adjM[uid, aid] = adjM[aid, uid] = max(1, weight)\n",
    "\n",
    "# user-user links: only the (user, friend) direction of each listed row is set\n",
    "for user_id, friend_id in user_friend[['userID', 'friendID']].itertuples(index=False, name=None):\n",
    "    adjM[user_id - 1, friend_id - 1] = 1\n",
    "\n",
    "# artist-tag links, stored symmetrically; every listed row adds 1, so repeated pairs are counted\n",
    "for artist_id, tag_id in artist_tag[['artistID', 'tagID']].itertuples(index=False, name=None):\n",
    "    aid = num_user + artist_id - 1\n",
    "    tid = num_user + num_artist + tag_id - 1\n",
    "    adjM[aid, tid] += 1\n",
    "    adjM[tid, aid] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Optional step, currently disabled: every statement in this cell is commented out.\n",
    "# If re-enabled it zeroes artist-tag entries whose count is 1, drops the tags left without any\n",
    "# link, and rebuilds num_tag, dim, type_mask and adjM for the reduced tag set.\n",
    "\n",
    "# filter out artist-tag links with counts less than 2\n",
    "# adjM[num_user:num_user+num_artist, num_user+num_artist:] = adjM[num_user:num_user+num_artist, num_user+num_artist:] * (adjM[num_user:num_user+num_artist, num_user+num_artist:] > 1)\n",
    "# adjM[num_user+num_artist:, num_user:num_user+num_artist] = np.transpose(adjM[num_user:num_user+num_artist, num_user+num_artist:])\n",
    "\n",
    "# valid_tag_idx = adjM[num_user:num_user+num_artist, num_user+num_artist:].sum(axis=0).nonzero()[0]\n",
    "# num_tag = len(valid_tag_idx)\n",
    "# dim = num_user + num_artist + num_tag\n",
    "# type_mask = np.zeros((dim), dtype=int)\n",
    "# type_mask[num_user:num_user+num_artist] = 1\n",
    "# type_mask[num_user+num_artist:] = 2\n",
    "\n",
    "# adjM_reduced = np.zeros((dim, dim), dtype=int)\n",
    "# adjM_reduced[:num_user+num_artist, :num_user+num_artist] = adjM[:num_user+num_artist, :num_user+num_artist]\n",
    "# adjM_reduced[num_user:num_user+num_artist, num_user+num_artist:] = adjM[num_user:num_user+num_artist, num_user+num_artist:][:, valid_tag_idx]\n",
    "# adjM_reduced[num_user+num_artist:, num_user:num_user+num_artist] = np.transpose(adjM_reduced[num_user:num_user+num_artist, num_user+num_artist:])\n",
    "\n",
    "# adjM = adjM_reduced"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Neighbor lists per node, read row by row from the typed sub-blocks of adjM.\n",
    "# Keys and stored neighbor indices are both local to their node type (0-based within the type).\n",
    "user_artist_list = {\n",
    "    u: row.nonzero()[0]\n",
    "    for u, row in enumerate(adjM[:num_user, num_user:num_user+num_artist])\n",
    "}\n",
    "artist_user_list = {\n",
    "    a: row.nonzero()[0]\n",
    "    for a, row in enumerate(adjM[num_user:num_user+num_artist, :num_user])\n",
    "}\n",
    "user_user_list = {\n",
    "    u: row.nonzero()[0]\n",
    "    for u, row in enumerate(adjM[:num_user, :num_user])\n",
    "}\n",
    "artist_tag_list = {\n",
    "    a: row.nonzero()[0]\n",
    "    for a, row in enumerate(adjM[num_user:num_user+num_artist, num_user+num_artist:])\n",
    "}\n",
    "tag_artist_list = {\n",
    "    t: row.nonzero()[0]\n",
    "    for t, row in enumerate(adjM[num_user+num_artist:num_user+num_artist+num_tag, num_user:num_user+num_artist])\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Metapath instances are integer arrays of global node indices, one row per instance, ordered by\n",
    "# (first node, last node, intermediate nodes).\n",
    "# np.lexsort reads its keys from least to most significant, so the primary column is passed last.\n",
    "# Every sort below uses all columns as keys, hence the resulting row order is fully determined.\n",
    "# dtype + reshape keep each array 2-D even when a metapath has no instance at all.\n",
    "\n",
    "# 0-1-0\n",
    "u_a_u = []\n",
    "for a, u_list in artist_user_list.items():\n",
    "    u_a_u.extend([(u1, a, u2) for u1 in u_list for u2 in u_list])\n",
    "u_a_u = np.array(u_a_u, dtype=np.int64).reshape(-1, 3)\n",
    "u_a_u[:, 1] += num_user\n",
    "sorted_index = np.lexsort((u_a_u[:, 1], u_a_u[:, 2], u_a_u[:, 0]))\n",
    "u_a_u = u_a_u[sorted_index]\n",
    "\n",
    "# 1-2-1\n",
    "a_t_a = []\n",
    "for t, a_list in tag_artist_list.items():\n",
    "    a_t_a.extend([(a1, t, a2) for a1 in a_list for a2 in a_list])\n",
    "a_t_a = np.array(a_t_a, dtype=np.int64).reshape(-1, 3)\n",
    "a_t_a += num_user\n",
    "a_t_a[:, 1] += num_artist\n",
    "sorted_index = np.lexsort((a_t_a[:, 1], a_t_a[:, 2], a_t_a[:, 0]))\n",
    "a_t_a = a_t_a[sorted_index]\n",
    "\n",
    "# 0-1-2-1-0\n",
    "# Seed the global NumPy RNG so that the user subsampling below is reproducible across runs\n",
    "# (same seed value as the negative-sampling cell at the end of the notebook).\n",
    "np.random.seed(453289)\n",
    "u_a_t_a_u = []\n",
    "for a1, t, a2 in a_t_a:\n",
    "    users_of_a1 = artist_user_list[a1 - num_user]\n",
    "    users_of_a2 = artist_user_list[a2 - num_user]\n",
    "    if len(users_of_a1) == 0 or len(users_of_a2) == 0:\n",
    "        continue\n",
    "    # keep a random 20% (rounded down) of the users attached to each end artist\n",
    "    candidate_u1_list = users_of_a1[np.random.choice(len(users_of_a1), int(0.2 * len(users_of_a1)), replace=False)]\n",
    "    candidate_u2_list = users_of_a2[np.random.choice(len(users_of_a2), int(0.2 * len(users_of_a2)), replace=False)]\n",
    "    u_a_t_a_u.extend([(u1, a1, t, a2, u2) for u1 in candidate_u1_list for u2 in candidate_u2_list])\n",
    "u_a_t_a_u = np.array(u_a_t_a_u, dtype=np.int64).reshape(-1, 5)\n",
    "sorted_index = np.lexsort((u_a_t_a_u[:, 3], u_a_t_a_u[:, 2], u_a_t_a_u[:, 1], u_a_t_a_u[:, 4], u_a_t_a_u[:, 0]))\n",
    "u_a_t_a_u = u_a_t_a_u[sorted_index]\n",
    "\n",
    "# 0-0\n",
    "u_u = user_friend.to_numpy(dtype=np.int32) - 1\n",
    "sorted_index = np.lexsort((u_u[:, 1], u_u[:, 0]))\n",
    "u_u = u_u[sorted_index]\n",
    "\n",
    "# 1-0-1\n",
    "a_u_a = []\n",
    "for u, a_list in user_artist_list.items():\n",
    "    a_u_a.extend([(a1, u, a2) for a1 in a_list for a2 in a_list])\n",
    "a_u_a = np.array(a_u_a, dtype=np.int64).reshape(-1, 3)\n",
    "a_u_a[:, [0, 2]] += num_user\n",
    "sorted_index = np.lexsort((a_u_a[:, 1], a_u_a[:, 2], a_u_a[:, 0]))\n",
    "a_u_a = a_u_a[sorted_index]\n",
    "\n",
    "# 1-0-0-1\n",
    "a_u_u_a = []\n",
    "for u1, u2 in u_u:\n",
    "    a_u_u_a.extend([(a1, u1, u2, a2) for a1 in user_artist_list[u1] for a2 in user_artist_list[u2]])\n",
    "a_u_u_a = np.array(a_u_u_a, dtype=np.int64).reshape(-1, 4)\n",
    "a_u_u_a[:, [0, 3]] += num_user\n",
    "sorted_index = np.lexsort((a_u_u_a[:, 2], a_u_u_a[:, 1], a_u_u_a[:, 3], a_u_u_a[:, 0]))\n",
    "a_u_u_a = a_u_u_a[sorted_index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "expected_metapaths = [\n",
    "    [(0, 1, 0), (0, 1, 2, 1, 0), (0, 0)],\n",
    "    [(1, 0, 1), (1, 2, 1), (1, 0, 0, 1)]\n",
    "]\n",
    "# create the directories if they do not exist\n",
    "# (one sub-directory per entry of expected_metapaths, named after its position in the list)\n",
    "for group_id in range(len(expected_metapaths)):\n",
    "    pathlib.Path(save_prefix + '{}'.format(group_id)).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "metapath_indices_mapping = {\n",
    "    (0, 1, 0): u_a_u,\n",
    "    (0, 1, 2, 1, 0): u_a_t_a_u,\n",
    "    (0, 0): u_u,\n",
    "    (1, 0, 1): a_u_a,\n",
    "    (1, 2, 1): a_t_a,\n",
    "    (1, 0, 0, 1): a_u_u_a,\n",
    "}\n",
    "\n",
    "# write all things\n",
    "target_idx_lists = [np.arange(num_user), np.arange(num_artist)]\n",
    "offset_list = [0, num_user]\n",
    "for i, metapaths in enumerate(expected_metapaths):\n",
    "    targets = target_idx_lists[i]\n",
    "    offset = offset_list[i]\n",
    "    for metapath in metapaths:\n",
    "        edge_metapath_idx_array = metapath_indices_mapping[metapath]\n",
    "        num_instances = len(edge_metapath_idx_array)\n",
    "        file_stem = save_prefix + '{}/'.format(i) + '-'.join(map(str, metapath))\n",
    "\n",
    "        # One forward scan over the instance rows: for each target node, in increasing order,\n",
    "        # collect the run of consecutive rows [start, stop) whose first column equals that node.\n",
    "        row_ranges = []\n",
    "        stop = 0\n",
    "        for target_idx in targets:\n",
    "            start = stop\n",
    "            while stop < num_instances and edge_metapath_idx_array[stop, 0] == target_idx + offset:\n",
    "                stop += 1\n",
    "            row_ranges.append((target_idx, start, stop))\n",
    "\n",
    "        # pickle: target node -> its instance rows with the column order reversed\n",
    "        with open(file_stem + '_idx.pickle', 'wb') as out_file:\n",
    "            target_metapaths_mapping = {\n",
    "                target_idx: edge_metapath_idx_array[start:stop, ::-1]\n",
    "                for target_idx, start, stop in row_ranges\n",
    "            }\n",
    "            pickle.dump(target_metapaths_mapping, out_file)\n",
    "\n",
    "        #np.save(save_prefix + '{}/'.format(i) + '-'.join(map(str, metapath)) + '_idx.npy', edge_metapath_idx_array)\n",
    "\n",
    "        # adjlist: one text line per target node -- the node followed by the last column of each\n",
    "        # of its instances, shifted back by the type offset\n",
    "        with open(file_stem + '.adjlist', 'w') as out_file:\n",
    "            for target_idx, start, stop in row_ranges:\n",
    "                neighbors = list(map(str, edge_metapath_idx_array[start:stop, -1] - offset))\n",
    "                if len(neighbors) == 0:\n",
    "                    out_file.write('{}\\n'.format(target_idx))\n",
    "                else:\n",
    "                    out_file.write('{} '.format(target_idx) + ' '.join(neighbors) + '\\n')\n",
    "\n",
    "scipy.sparse.save_npz(save_prefix + 'adjM.npz', scipy.sparse.csr_matrix(adjM))\n",
    "np.save(save_prefix + 'node_types.npy', type_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# output user_artist.npy\n",
    "# The table is read again from disk and only the two ID columns are kept; both are shifted\n",
    "# by -1 before the pairs are saved.\n",
    "user_artist = pd.read_csv(\n",
    "    'data/raw/LastFM_magnn/user_artist.dat',\n",
    "    encoding='utf-8',\n",
    "    delimiter='\\t',\n",
    "    names=['userID', 'artistID', 'weight'],\n",
    ")\n",
    "user_artist = user_artist.loc[:, ['userID', 'artistID']].to_numpy() - 1\n",
    "np.save(save_prefix + 'user_artist.npy', user_artist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# output positive and negative samples for training, validation and testing\n",
    "\n",
    "np.random.seed(453289)\n",
    "save_prefix = 'data/preprocessed/LastFM_magnn_processed/'\n",
    "num_user = 1892\n",
    "num_artist = 17632\n",
    "user_artist = np.load('data/preprocessed/LastFM_magnn_processed/user_artist.npy')\n",
    "train_val_test_idx = np.load('data/raw/LastFM_magnn/train_val_test_idx.npz')\n",
    "train_idx = train_val_test_idx['train_idx']\n",
    "val_idx = train_val_test_idx['val_idx']\n",
    "test_idx = train_val_test_idx['test_idx']\n",
    "\n",
    "\n",
    "def _negative_pairs(positive_pairs, num_rows, num_cols):\n",
    "    '''Return every (row, col) pair of a num_rows x num_cols grid that is absent from positive_pairs.\n",
    "\n",
    "    positive_pairs is an (n, 2) integer array of 0-based (row, col) pairs. The result is an\n",
    "    (m, 2) integer array listed in row-major order, i.e. the order in which a nested\n",
    "    `for row: for col:` loop visits the grid. Unlike a merge-style scan with a running counter,\n",
    "    the mask does not require positive_pairs to be sorted or free of duplicates.\n",
    "    '''\n",
    "    is_positive = np.zeros((num_rows, num_cols), dtype=bool)\n",
    "    is_positive[positive_pairs[:, 0], positive_pairs[:, 1]] = True\n",
    "    return np.argwhere(~is_positive)\n",
    "\n",
    "\n",
    "# validation / test negatives: drawn without replacement from the pairs that never occur in the\n",
    "# complete interaction list, then split in draw order and sorted within each split\n",
    "neg_candidates = _negative_pairs(user_artist, num_user, num_artist)\n",
    "\n",
    "idx = np.random.choice(len(neg_candidates), len(val_idx) + len(test_idx), replace=False)\n",
    "val_neg_candidates = neg_candidates[sorted(idx[:len(val_idx)])]\n",
    "test_neg_candidates = neg_candidates[sorted(idx[len(val_idx):])]\n",
    "\n",
    "# training negatives: every pair that is not a training positive (no subsampling here)\n",
    "train_user_artist = user_artist[train_idx]\n",
    "train_neg_candidates = _negative_pairs(train_user_artist, num_user, num_artist)\n",
    "\n",
    "np.savez(save_prefix + 'train_val_test_neg_user_artist.npz',\n",
    "         train_neg_user_artist=train_neg_candidates,\n",
    "         val_neg_user_artist=val_neg_candidates,\n",
    "         test_neg_user_artist=test_neg_candidates)\n",
    "np.savez(save_prefix + 'train_val_test_pos_user_artist.npz',\n",
    "         train_pos_user_artist=user_artist[train_idx],\n",
    "         val_pos_user_artist=user_artist[val_idx],\n",
    "         test_pos_user_artist=user_artist[test_idx])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
